import copy
import csv
from os import path

import numpy as np

class person:
    """One individual together with their income and RRSP amounts.

    Income sources are stored under an ``inc_`` prefix (``earn`` becomes
    ``inc_earn`` and so on); RRSP contributions are kept in ``con_rrsp``.
    ``inc_disp`` starts at zero.
    """

    def __init__(self, age=50, oas=0.0, gis=0.0, earn=0.0, rpp=0.0, cpp=0.0,
                 othtax=0.0, othntax=0.0, con_rrsp=0.0, inc_rrsp=0.0):
        self.age = age
        # Income sources, assigned in the same order as before.
        self.inc_earn, self.inc_oas, self.inc_gis = earn, oas, gis
        self.inc_rpp, self.inc_cpp = rpp, cpp
        # RRSP contributions and withdrawals.
        self.con_rrsp, self.inc_rrsp = con_rrsp, inc_rrsp
        # Other taxable / non-taxable income.
        self.inc_othtax, self.inc_othntax = othtax, othntax
        self.inc_disp = 0.0

class hhold:
    """A household of one person or a couple.

    ``sp`` lists the members (``first``, then ``second`` when given),
    ``couple`` tells whether a second member is present, and ``prov`` is the
    province code read by the tax forms.
    """
    def __init__(self,first,second=None,prov='qc'):
        self.sp = [first]
        self.couple = False
        self.prov = prov
        # Identity test: None is a singleton, and `!=` would dispatch to the
        # member object's __eq__/__ne__.
        if second is not None:
            self.sp.append(second)
            self.couple = True
                 

class federal:
    """Federal income tax return for one member of a household.

    ``hhold`` is the household, ``who`` the index of the filer in
    ``hhold.sp`` and ``rules`` the federal parameter object (bracket
    thresholds, rates and credit amounts). Call ``file()`` to fill in the
    return; each intermediate line is then available as an attribute.
    """
    def __init__(self,hhold,who,rules):
        self.hhold = hhold
        self.who = who
        self.totinc = 0.0   # total income
        self.deduc = 0.0    # deductions (RRSP contributions)
        self.taxinc = 0.0   # taxable income
        self.tax = 0.0      # tax from the bracket schedule
        self.ntcred = 0.0   # non-refundable credits
        self.liab = 0.0     # tax after non-refundable credits, floored at 0
        self.rtcred = 0.0   # refundable credits (Quebec abatement)
        self.taxpay = 0.0
        self.dspinc = 0.0
        self.rules = rules
    def file(self):
        """Fill in the return from total income down to refundable credits.

        The accumulating lines are reset first so that filing the same form
        twice gives the same result instead of double-counting.
        """
        self.totinc = 0.0
        self.deduc = 0.0
        self.ntcred = 0.0
        self.rtcred = 0.0
        self.calc_totinc()
        self.calc_deduc()
        self.calc_taxinc()
        self.calc_tax()
        self.calc_ntcred()
        # Non-refundable credits can only bring tax down to zero.
        self.liab = max(self.tax - self.ntcred, 0.0)
        self.calc_rtcred()
    def calc_totinc(self):
        """Add every income source of the filer to ``totinc``."""
        p = self.hhold.sp[self.who]
        self.totinc += (p.inc_earn + p.inc_oas + p.inc_gis + p.inc_rpp
                        + p.inc_cpp + p.inc_othtax + p.inc_othntax
                        + p.inc_rrsp)
    def calc_deduc(self):
        """Add the filer's RRSP contributions to ``deduc``."""
        p = self.hhold.sp[self.who]
        self.deduc += p.con_rrsp
    def calc_taxinc(self):
        """Taxable income: total income less non-taxable income and deductions.

        GIS and other non-taxable income (``inc_othntax``) are part of total
        income, and hence of disposable income, but are not taxable.
        """
        p = self.hhold.sp[self.who]
        self.taxinc = self.totinc - p.inc_gis - p.inc_othntax - self.deduc
    def calc_tax(self):
        """Apply the progressive schedule ``rules.brack`` / ``rules.rates``.

        ``rates[k]`` applies to income between ``brack[k-1]`` (0 for k=0) and
        ``brack[k]``. If ``rates`` has one more entry than ``brack``, that
        extra rate applies to income above the last threshold; ``zip`` alone
        would silently drop it. Negative taxable income yields zero tax.
        """
        brack = self.rules.brack
        rates = self.rules.rates
        inc = max(self.taxinc, 0.0)
        t = 0.0
        lower = 0.0
        for b, r in zip(brack, rates):
            if inc > b:
                t += r*(b - lower)
                lower = b
            else:
                t += r*(inc - lower)
                break
        else:
            # Income is above every threshold: tax the remainder at the top
            # rate when one is supplied beyond the last threshold.
            if len(rates) > len(brack):
                t += rates[len(brack)]*(inc - lower)
        self.tax = t
    def calc_ntcred(self):
        """Basic personal, age and pension credits, summed into ``ntcred``."""
        self.ntcred += self.rules.nrtc_rate*self.rules.base
        self.ntcred += self.get_agecred()
        self.ntcred += self.get_pencred()
    def get_agecred(self):
        """Age credit: age amount reduced above an income threshold.

        Filers aged ``nrtc_age`` or more get ``nrtc_age_max``, reduced by
        ``nrtc_age_rate`` per dollar of taxable income above
        ``nrtc_age_base`` (floored at zero), then converted to a credit at
        ``nrtc_rate``. Others get 0.
        """
        p = self.hhold.sp[self.who]
        r = self.rules
        if p.age < r.nrtc_age:
            return 0.0
        amount = r.nrtc_age_max
        if self.taxinc > r.nrtc_age_base:
            amount = max(amount - r.nrtc_age_rate*(self.taxinc - r.nrtc_age_base), 0.0)
        return amount*r.nrtc_rate
    def get_pencred(self):
        """Pension credit: RPP income capped at ``nrtc_pension_max``, at ``nrtc_rate``."""
        r = self.rules
        return min(self.hhold.sp[self.who].inc_rpp, r.nrtc_pension_max)*r.nrtc_rate
    def calc_rtcred(self):
        """Add refundable credits (only the Quebec abatement) to ``rtcred``."""
        self.rtcred += self.calc_abatment()
    def calc_abatment(self):
        """For a Quebec household, return ``(1 - rules.abatment) * liab``; else 0."""
        amount = 0.0
        if self.hhold.prov == 'qc':
            amount += (1.0 - self.rules.abatment)*self.liab
        return amount
 
class quebec:
    """Quebec income tax return for one member of a household.

    ``hhold`` is the household, ``who`` the index of the filer in
    ``hhold.sp`` and ``rules`` the Quebec parameter object. Call ``file()``
    to fill in the return; each intermediate line is then an attribute.
    """
    def __init__(self,hhold,who,rules):
        self.hhold = hhold
        self.who = who
        self.totinc = 0.0   # total income
        self.deduc = 0.0    # deductions (RRSP contributions)
        self.taxinc = 0.0   # taxable income
        self.tax = 0.0      # tax from the bracket schedule
        self.ntcred = 0.0   # non-refundable credits
        self.liab = 0.0     # tax after non-refundable credits, floored at 0
        self.rtcred = 0.0   # refundable credits (none modelled; stays 0)
        self.taxpay = 0.0
        self.dspinc = 0.0
        self.rules = rules
    def file(self):
        """Fill in the return from total income down to refundable credits.

        The accumulating lines are reset first so that filing the same form
        twice gives the same result instead of double-counting.
        """
        self.totinc = 0.0
        self.deduc = 0.0
        self.ntcred = 0.0
        self.rtcred = 0.0
        self.calc_totinc()
        self.calc_deduc()
        self.calc_taxinc()
        self.calc_tax()
        self.calc_ntcred()
        # Non-refundable credits can only bring tax down to zero.
        self.liab = max(self.tax - self.ntcred, 0.0)
        self.calc_rtcred()
    def calc_totinc(self):
        """Add every income source of the filer to ``totinc``."""
        p = self.hhold.sp[self.who]
        self.totinc += (p.inc_earn + p.inc_oas + p.inc_gis + p.inc_rpp
                        + p.inc_cpp + p.inc_othtax + p.inc_othntax
                        + p.inc_rrsp)
    def calc_deduc(self):
        """Add the filer's RRSP contributions to ``deduc``."""
        p = self.hhold.sp[self.who]
        self.deduc += p.con_rrsp
    def calc_taxinc(self):
        """Taxable income: total income less non-taxable income and deductions.

        GIS and other non-taxable income (``inc_othntax``) are part of total
        income, and hence of disposable income, but are not taxable.
        """
        p = self.hhold.sp[self.who]
        self.taxinc = self.totinc - p.inc_gis - p.inc_othntax - self.deduc
    def calc_tax(self):
        """Apply the progressive schedule ``rules.brack`` / ``rules.rates``.

        ``rates[k]`` applies to income between ``brack[k-1]`` (0 for k=0) and
        ``brack[k]``. If ``rates`` has one more entry than ``brack``, that
        extra rate applies to income above the last threshold; ``zip`` alone
        would silently drop it. Negative taxable income yields zero tax.
        """
        brack = self.rules.brack
        rates = self.rules.rates
        inc = max(self.taxinc, 0.0)
        t = 0.0
        lower = 0.0
        for b, r in zip(brack, rates):
            if inc >= b:
                t += r*(b - lower)
                lower = b
            else:
                t += r*(inc - lower)
                break
        else:
            # Income is at/above every threshold: tax the remainder at the
            # top rate when one is supplied beyond the last threshold.
            if len(rates) > len(brack):
                t += rates[len(brack)]*(inc - lower)
        self.tax = t
    def calc_ntcred(self):
        """Basic personal, age/living-alone and pension credits, into ``ntcred``."""
        self.ntcred += self.rules.nrtc_rate*self.rules.base
        self.ntcred += self.get_agecred()
        self.ntcred += self.get_pencred()
    def get_agecred(self):
        """Credit for the age amount plus the living-alone amount.

        The age amount ``nrtc_age_max`` is granted at age ``nrtc_age`` or
        more. The single amount ``nrtc_single`` is granted when the household
        is not a couple. The combined amount is reduced by ``nrtc_age_rate``
        per dollar of taxable income above ``nrtc_age_base`` (floored at
        zero). It is then converted to a credit at ``nrtc_rate``, like the
        basic and pension amounts in this class. Previously the raw amount
        was returned and counted as a 100% credit.
        """
        p = self.hhold.sp[self.who]
        r = self.rules
        inc = self.taxinc
        amount = 0.0
        if p.age >= r.nrtc_age:
            amount += r.nrtc_age_max
        if not self.hhold.couple:
            amount += r.nrtc_single
        # The income reduction applies to the combined amount.
        if inc > r.nrtc_age_base:
            amount = max(amount - r.nrtc_age_rate*(inc - r.nrtc_age_base), 0.0)
        return amount*r.nrtc_rate
    def get_pencred(self):
        """Pension credit: RPP income capped at ``nrtc_pension_max``, at ``nrtc_rate``."""
        r = self.rules
        return min(self.hhold.sp[self.who].inc_rpp, r.nrtc_pension_max)*r.nrtc_rate
    def calc_rtcred(self):
        """No Quebec refundable credits are modelled; ``rtcred`` stays as is."""
        pass

class oas:
    """Old Age Security benefit for one member of a household.

    ``who`` indexes ``hhold.sp``; ``rules`` supplies the eligibility age
    (``ageoas``), the full benefit (``oas_full``) and the clawback threshold
    and rate. After ``file()`` the benefit is in ``oasinc``.
    """
    def __init__(self,hhold,who,rules):
        self.hhold = hhold
        self.who = who
        self.p = self.hhold.sp[self.who]
        self.elig = False
        self.oasinc = 0.0
        self.rules = rules
    def oaselig(self):
        """Mark the person eligible when aged ``rules.ageoas`` or more."""
        if self.p.age >= self.rules.ageoas:
            self.elig = True
    def file(self):
        """Grant the full benefit if eligible, then apply the clawback."""
        self.oaselig()
        if self.elig:
            self.oasinc = self.rules.oas_full
        self.clawback()
    def clawback(self):
        """Reduce ``oasinc`` by ``oas_clawback_rate`` per dollar above ``oas_clawback``.

        The income tested is the person's taxable income sources other than
        OAS itself. RRSP withdrawals (``inc_rrsp``) are taxable and are now
        included; previously they were left out. The benefit never goes
        below zero.
        """
        p = self.p
        inc = p.inc_earn + p.inc_rpp + p.inc_cpp + p.inc_othtax + p.inc_rrsp
        claw = max(self.rules.oas_clawback_rate*(inc - self.rules.oas_clawback), 0.0)
        self.oasinc = max(self.oasinc - claw, 0.0)
class gis:
    """Guaranteed Income Supplement for one member of a household.

    ``who`` indexes ``hhold.sp``; ``rules`` supplies the eligibility age,
    full amounts (single/couple), reduction rates, the exemption and the
    single-person bonus parameters. After ``file()`` the benefit is in
    ``gisinc``.
    """
    def __init__(self,hhold,who,rules):
        self.hhold = hhold
        self.who = who
        self.p = self.hhold.sp[self.who]
        self.elig = False
        self.gisinc = 0.0
        self.rules = rules
    @staticmethod
    def _income(s):
        """Income tested for GIS: taxable sources other than OAS and GIS.

        RRSP withdrawals (``inc_rrsp``) are taxable and are now included;
        previously they were left out.
        """
        return s.inc_earn + s.inc_rpp + s.inc_cpp + s.inc_othtax + s.inc_rrsp
    def giselig(self):
        """Mark the person eligible when aged ``rules.agegis`` or more."""
        if self.p.age >= self.rules.agegis:
            self.elig = True
    def file(self):
        """Grant the full couple/single amount, claw it back, add the bonus."""
        self.giselig()
        if self.elig:
            if self.hhold.couple:
                self.gisinc = self.rules.gis_full_couple
            else:
                self.gisinc = self.rules.gis_full_single
            self.clawback()
            self.gisinc += self.get_bonus()
    def clawback(self):
        """Reduce ``gisinc`` for income above ``gis_work_exemption``.

        Couples are tested on the combined income of all members at the
        couple rate; singles on their own income at the single rate. The
        benefit never goes below zero.
        """
        if self.hhold.couple:
            rate = self.rules.gis_reduct_rate_couple
            inc = sum(self._income(s) for s in self.hhold.sp)
        else:
            rate = self.rules.gis_reduct_rate_single
            inc = self._income(self.p)
        claw = max(rate*(inc - self.rules.gis_work_exemption), 0.0)
        self.gisinc = max(self.gisinc - claw, 0.0)
    def get_bonus(self):
        """Single-person bonus, reduced above its own exemption (0 for couples)."""
        if self.hhold.couple:
            return 0.0
        r = self.rules
        claw = max(r.gis_bonus_reduct_single*(self._income(self.p) - r.gis_bonus_exemption_single), 0.0)
        return max(r.gis_bonus_single - claw, 0.0)
class fedpars:
    """Federal tax parameters for one tax year.

    Parameters live in ``pars/<year>/fed/values.csv`` next to this module;
    ``loadpars`` turns each row into an attribute.
    """
    def __init__(self,year):
        self.year = year
        pars = path.join(path.dirname(__file__), 'pars/')
        # Directory holding this year's values.csv (kept with trailing '/').
        self.path = pars+str(self.year)+'/fed/'
    def loadpars(self):
        """Load every row of ``values.csv`` as an attribute of ``self``.

        Rows are ``name;v1;v2;...`` (semicolon-delimited). Empty cells are
        ignored and the rest parsed as floats. A row with exactly one value
        becomes a scalar ``float``, otherwise a (possibly empty) ``list``.
        Blank lines are skipped, and a UTF-8 BOM (e.g. from spreadsheet
        exports) is stripped so it does not corrupt the first name.
        """
        fname = path.join(self.path, 'values.csv')
        with open(fname, newline='', encoding='utf-8-sig') as csvfile:
            for row in csv.reader(csvfile, delimiter=';'):
                # Blank lines come back as [] and would fail on row[0].
                if not row or not row[0].strip():
                    continue
                values = [float(r) for r in row[1:] if r != '']
                setattr(self, row[0].strip(), values[0] if len(values) == 1 else values)
class qcpars:
    """Quebec tax parameters for one tax year.

    Parameters live in ``pars/<year>/qc/values.csv`` next to this module;
    ``loadpars`` turns each row into an attribute.
    """
    def __init__(self,year):
        self.year = year
        pars = path.join(path.dirname(__file__), 'pars/')
        # Directory holding this year's values.csv (kept with trailing '/').
        self.path = pars+str(self.year)+'/qc/'
    def loadpars(self):
        """Load every row of ``values.csv`` as an attribute of ``self``.

        Rows are ``name;v1;v2;...`` (semicolon-delimited). Empty cells are
        ignored and the rest parsed as floats. A row with exactly one value
        becomes a scalar ``float``, otherwise a (possibly empty) ``list``.
        Blank lines are skipped, and a UTF-8 BOM (e.g. from spreadsheet
        exports) is stripped so it does not corrupt the first name.
        """
        fname = path.join(self.path, 'values.csv')
        with open(fname, newline='', encoding='utf-8-sig') as csvfile:
            for row in csv.reader(csvfile, delimiter=';'):
                # Blank lines come back as [] and would fail on row[0].
                if not row or not row[0].strip():
                    continue
                values = [float(r) for r in row[1:] if r != '']
                setattr(self, row[0].strip(), values[0] if len(values) == 1 else values)
class oaspars:
    """Old Age Security parameters for one year.

    Parameters live in ``pars/<year>/oas/values.csv`` next to this module;
    ``loadpars`` turns each row into an attribute.
    """
    def __init__(self,year):
        self.year = year
        pars = path.join(path.dirname(__file__), 'pars/')
        # Directory holding this year's values.csv (kept with trailing '/').
        self.path = pars+str(self.year)+'/oas/'
    def loadpars(self):
        """Load every row of ``values.csv`` as an attribute of ``self``.

        Rows are ``name;v1;v2;...`` (semicolon-delimited). Empty cells are
        ignored and the rest parsed as floats. A row with exactly one value
        becomes a scalar ``float``, otherwise a (possibly empty) ``list``.
        Blank lines are skipped, and a UTF-8 BOM (e.g. from spreadsheet
        exports) is stripped so it does not corrupt the first name.
        """
        fname = path.join(self.path, 'values.csv')
        with open(fname, newline='', encoding='utf-8-sig') as csvfile:
            for row in csv.reader(csvfile, delimiter=';'):
                # Blank lines come back as [] and would fail on row[0].
                if not row or not row[0].strip():
                    continue
                values = [float(r) for r in row[1:] if r != '']
                setattr(self, row[0].strip(), values[0] if len(values) == 1 else values)
class gispars:
    """Guaranteed Income Supplement parameters for one year.

    Parameters live in ``pars/<year>/gis/values.csv`` next to this module;
    ``loadpars`` turns each row into an attribute.
    """
    def __init__(self,year):
        self.year = year
        pars = path.join(path.dirname(__file__), 'pars/')
        # Directory holding this year's values.csv (kept with trailing '/').
        self.path = pars+str(self.year)+'/gis/'
    def loadpars(self):
        """Load every row of ``values.csv`` as an attribute of ``self``.

        Rows are ``name;v1;v2;...`` (semicolon-delimited). Empty cells are
        ignored and the rest parsed as floats. A row with exactly one value
        becomes a scalar ``float``, otherwise a (possibly empty) ``list``.
        Blank lines are skipped, and a UTF-8 BOM (e.g. from spreadsheet
        exports) is stripped so it does not corrupt the first name.
        """
        fname = path.join(self.path, 'values.csv')
        with open(fname, newline='', encoding='utf-8-sig') as csvfile:
            for row in csv.reader(csvfile, delimiter=';'):
                # Blank lines come back as [] and would fail on row[0].
                if not row or not row[0].strip():
                    continue
                values = [float(r) for r in row[1:] if r != '']
                setattr(self, row[0].strip(), values[0] if len(values) == 1 else values)
  
class tax:
    """Tax-and-transfer calculator for a household in a given year.

    Loads federal, provincial, OAS and GIS parameters for ``year`` and files
    every program for each member of a household.

    Only ``prov='qc'`` is wired up. For other provinces no provincial
    parameters are loaded and no provincial form is filed, so the after-tax
    helpers (which read ``self.proforms``) raise ``AttributeError``.
    """
    def __init__(self,year=2016,prov='qc'):
        self.year = year
        self.prov = prov
        self.loadfederal()
        self.loadprovincial()
        self.loadoas()
        self.loadgis()
    def loadfederal(self):
        """Load federal parameters into ``self.fed``."""
        self.fed = fedpars(self.year)
        self.fed.loadpars()
    def loadprovincial(self):
        """Load provincial parameters into ``self.pro`` (Quebec only)."""
        if self.prov == 'qc':
            self.pro = qcpars(self.year)
            self.pro.loadpars()
    def loadoas(self):
        """Load OAS parameters into ``self.oasp``."""
        self.oasp = oaspars(self.year)
        self.oasp.loadpars()
    def loadgis(self):
        """Load GIS parameters into ``self.gisp``."""
        self.gisp = gispars(self.year)
        self.gisp.loadpars()
    def file(self,hhold):
        """File every program for ``hhold``, in dependency order.

        OAS and GIS come first because they write ``inc_oas`` / ``inc_gis``
        on each person, which the tax forms then read as income.
        """
        self.fileoas(hhold)
        self.filegis(hhold)
        self.filefed(hhold)
        self.filepro(hhold)
    def fileoas(self,hhold):
        """Compute OAS for each member and store it in ``inc_oas``."""
        for i, s in enumerate(hhold.sp):
            form = oas(hhold,i,self.oasp)
            form.file()
            s.inc_oas = form.oasinc
    def filegis(self,hhold):
        """Compute GIS for each member and store it in ``inc_gis``."""
        for i, s in enumerate(hhold.sp):
            form = gis(hhold,i,self.gisp)
            form.file()
            s.inc_gis = form.gisinc
    def filefed(self,hhold):
        """File a federal return per member into ``self.fedforms``."""
        self.fedforms = []
        for i in range(len(hhold.sp)):
            form = federal(hhold,i,self.fed)
            self.fedforms.append(form)
            form.file()
    def filepro(self,hhold):
        """File a provincial return per member into ``self.proforms`` (Quebec only)."""
        if self.prov == 'qc':
            self.proforms = []
            for i in range(len(hhold.sp)):
                form = quebec(hhold,i,self.pro)
                self.proforms.append(form)
                form.file()
    def paftertax(self,hhold,who):
        """After-tax income of member ``who`` from the last filed forms.

        Total income minus federal and provincial liability, plus federal and
        provincial refundable credits. ``hhold`` is unused; kept for API.
        """
        fed = self.fedforms[who]
        pro = self.proforms[who]
        return fed.totinc - fed.liab - pro.liab + fed.rtcred + pro.rtcred
    def haftertax(self,hhold):
        """Sum of ``paftertax`` over all household members."""
        return sum((self.paftertax(hhold,i) for i in range(len(hhold.sp))), 0.0)
    def pinc(self,hhold,who):
        """Total (pre-tax) income of member ``who`` from the last federal form."""
        return self.fedforms[who].totinc
    def hinc(self,hhold):
        """Sum of ``pinc`` over all household members."""
        return sum((self.pinc(hhold,i) for i in range(len(hhold.sp))), 0.0)
    def patr(self,hhold,who):
        """Average tax rate of member ``who`` (raises ZeroDivisionError at zero income)."""
        return 1.0 - self.paftertax(hhold,who)/self.pinc(hhold,who)
    def pmtr(self,hhold,who,incre):
        """Marginal effective tax rate of member ``who`` for ``incre`` more earnings.

        Files ``hhold`` as is, then files a deep copy whose member ``who``
        earns ``incre`` more. Returns one minus the share of the extra
        earnings kept after tax. The caller's household is not modified
        beyond what ``file`` writes (``inc_oas`` / ``inc_gis``). On return
        the stored forms describe the bumped copy. Raises ZeroDivisionError
        when ``incre`` is 0.
        """
        self.file(hhold)
        inc = self.paftertax(hhold,who)
        # Work on a copy: bumping earnings on ``hhold`` itself would
        # permanently change the caller's household.
        hholdp = copy.deepcopy(hhold)
        hholdp.sp[who].inc_earn += incre
        self.file(hholdp)
        incp = self.paftertax(hholdp,who)
        return 1.0 - (incp - inc)/incre
    def hatr(self,hhold):
        """Household average tax rate (raises ZeroDivisionError at zero income)."""
        return 1.0 - self.haftertax(hhold)/self.hinc(hhold)
    def hmtr(self,hhold,who,incre):
        """Household marginal effective tax rate when member ``who`` earns ``incre`` more.

        Same procedure as ``pmtr`` but on household after-tax income; the
        caller's household is likewise not modified by the earnings bump.
        Raises ZeroDivisionError when ``incre`` is 0.
        """
        self.file(hhold)
        inc = self.haftertax(hhold)
        # Work on a copy: bumping earnings on ``hhold`` itself would
        # permanently change the caller's household.
        hholdp = copy.deepcopy(hhold)
        hholdp.sp[who].inc_earn += incre
        self.file(hholdp)
        incp = self.haftertax(hholdp)
        return 1.0 - (incp - inc)/incre
     